import io
import pickle
import random
import sys
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

from .utils import PipelineStage

# Shared type variables and aliases used by the stub signatures below.
T = TypeVar('T')  # generic element type flowing through a pipeline stage
S = TypeVar('S')  # intermediate type for function composition
# Error handler: receives the exception; presumably returns True to continue
# the stream and False (or re-raises) to stop -- confirm in the implementation.
Handler = Callable[[Exception], bool]
# A sample: a mapping from field name to (possibly decoded) value.
Sample = Dict[str, Any]
# A decoder table: (suffix/pattern, decoder callable) pairs.
# NOTE(review): declared here but the decoder signatures near the bottom of
# this stub spell the same type out inline instead of using this alias.
Decoder = List[Tuple[str, Callable]]

# Classes
# Classes
class FilterFunction:
    """A concrete pipeline filter: a callable bundled with the positional and
    keyword arguments it will be applied with.  Calling the instance feeds the
    data stream to the wrapped function (presumably ``f(data, *args, **kw)`` --
    body not visible in this stub)."""

    f: Callable          # the wrapped filter implementation
    args: Tuple          # positional arguments captured at construction time
    kw: Dict[str, Any]   # keyword arguments captured at construction time

    def __init__(self, f: Callable, *args: Any, **kw: Any) -> None: ...
    def __call__(self, data: Any) -> Any: ...
    def __str__(self) -> str: ...
    def __repr__(self) -> str: ...

class RestCurried:
    """Curried form of a filter: calling the instance with ``*args, **kw``
    returns a :class:`FilterFunction` that still awaits its data argument.
    Instances are produced by :func:`pipelinefilter` (see its return type)."""

    f: Callable  # the underlying filter implementation (e.g. ``_shuffle``)

    def __init__(self, f: Callable) -> None: ...
    def __call__(self, *args: Any, **kw: Any) -> FilterFunction: ...

# Decorators
# Turns a data-first filter function ``_f(data, ...)`` into its curried form:
# the module-level ``f`` names typed ``RestCurried`` below (``info``,
# ``shuffle``, ``decode``, ...) are presumably created with this decorator.
def pipelinefilter(f: Callable) -> RestCurried: ...

# Basic functions
def reraise_exception(exn: Exception) -> bool: ...  # default Handler: presumably re-raises ``exn`` -- confirm in source
def identity(x: T) -> T: ...  # identity function: x -> x (per the T -> T signature)
def compose2(f: Callable[[T], S], g: Callable[[S], Any]) -> Callable[[T], Any]: ...  # x -> g(f(x)), per the annotations (result maps T -> Any)
def compose(*args: Callable) -> Callable: ...  # compose any number of callables into one
def pipeline(source: Any, *args: Callable) -> Any: ...  # apply the composed ``args`` stages to ``source``
def getfirst(
    a: Dict[str, Any],
    keys: Union[str, List[str]],
    default: Any = None,
    missing_is_error: bool = True,
) -> Any:
    """Signature stub for ``getfirst``; semantics live in the implementation
    module (presumably: return the value of the first of ``keys`` present in
    ``a``, else ``default`` or an error depending on ``missing_is_error``)."""

def parse_field_spec(fields: Union[str, List[str]]) -> List[List[str]]:
    """Signature stub for ``parse_field_spec``; splits a field specification
    into lists of alternative key names (presumably -- see source)."""

def transform_with(
    sample: List[Any],
    transformers: Optional[List[Optional[Callable]]] = None,
) -> List[Any]:
    """Signature stub for ``transform_with``; applies each transformer to the
    corresponding element of ``sample`` (presumably -- see source)."""

# Pipeline functions
# Debug/inspection stage: writes sample information to ``stream`` (format
# controlled by ``fmt``/``width``, frequency by ``n``/``every``) while passing
# the samples through unchanged (presumably -- bodies not visible here).
def _info(data: Iterable[Sample], fmt: Optional[str] = None, n: int = 3, every: int = -1, width: int = 50, stream: Any = sys.stderr, name: str = "") -> Iterator[Sample]: ...
info: RestCurried  # curried form of ``_info`` (via ``pipelinefilter``, presumably)

# Pick a random element of ``buf`` using ``rng``.
def pick(buf: List[T], rng: random.Random) -> T: ...

# Streaming shuffle over a bounded buffer of ``bufsize`` elements; yielding
# presumably begins once ``initial`` elements have been buffered.  Supplying
# ``rng`` or ``seed`` makes the shuffle reproducible.
def _shuffle(data: Iterable[T], bufsize: int = 1000, initial: int = 100, rng: Optional[random.Random] = None, seed: Optional[int] = None, handler: Optional[Handler] = None) -> Iterator[T]: ...
shuffle: RestCurried  # curried form of ``_shuffle``

class detshuffle(PipelineStage):
    """Deterministic shuffle stage: the permutation is presumably derived from
    ``seed`` and the current ``epoch`` so each epoch reshuffles reproducibly --
    implementation not visible in this stub; confirm against the source."""

    bufsize: int   # shuffle buffer capacity
    initial: int   # elements to buffer before yielding begins (presumably, as in ``_shuffle``)
    seed: int      # base random seed
    epoch: int     # epoch counter; defaults to -1 at construction

    def __init__(self, bufsize: int = 1000, initial: int = 100, seed: int = 0, epoch: int = -1) -> None: ...
    def run(self, src: Iterable[T]) -> Iterator[T]: ...

# NOTE(review): the semantic comments in this section are inferred from names
# and signatures of this stub; confirm against the implementation module.

# Keep only the elements for which ``predicate`` returns true.
def _select(data: Iterable[T], predicate: Callable[[T], bool]) -> Iterator[T]: ...
select: RestCurried

# Record sample keys to ``logfile`` while passing samples through.
def _log_keys(data: Iterable[Sample], logfile: Optional[str] = None) -> Iterator[Sample]: ...
log_keys: RestCurried

# Decode raw sample fields; ``args`` select decoders by name or supply custom
# callables, errors are routed through ``handler``.
def _decode(data: Iterable[Sample], *args: Union[str, Callable], handler: Handler = reraise_exception, **kw: Any) -> Iterator[Sample]: ...
decode: RestCurried

# Apply ``f`` to every element.  NOTE(review): the curried name shadows the
# ``map`` builtin at module scope; kept as-is for API compatibility.
def _map(data: Iterable[Any], f: Callable, handler: Handler = reraise_exception) -> Iterator[Any]: ...
map: RestCurried

# Rename sample fields via keyword arguments (``**kw: str`` -- new=old,
# presumably); ``keep`` controls whether unrenamed fields are retained.
def _rename(data: Iterable[Sample], handler: Handler = reraise_exception, keep: bool = True, **kw: str) -> Iterator[Sample]: ...
rename: RestCurried

# Merge extra fields into each sample, from a dict or a key -> dict callable.
def _associate(data: Iterable[Sample], associator: Union[Dict[str, Any], Callable[[str], Dict[str, Any]]], **kw: Any) -> Iterator[Sample]: ...
associate: RestCurried

# Apply per-field callables given as ``field=function`` keyword arguments.
def _map_dict(data: Iterable[Sample], handler: Handler = reraise_exception, **kw: Callable) -> Iterator[Sample]: ...
map_dict: RestCurried

# Convert dict samples to tuples by extracting the fields named in ``args``.
def _to_tuple(data: Iterable[Sample], *args: str, handler: Handler = reraise_exception, missing_is_error: bool = True, none_is_error: Optional[bool] = None) -> Iterator[Tuple]: ...
to_tuple: RestCurried

# Apply the i-th callable to the i-th tuple element (None entries presumably
# leave the element unchanged).
def _map_tuple(data: Iterable[Tuple], *args: Optional[Callable], handler: Handler = reraise_exception) -> Iterator[Tuple]: ...
map_tuple: RestCurried

def combine_values(
    b: List[Any],
    combine_tensors: bool = True,
    combine_scalars: bool = True,
) -> Any:
    """Signature stub for ``combine_values``: collate a list of per-sample
    values into one batch value (presumably -- see source)."""

def tuple2dict(l: Union[Dict[int, Any], Tuple]) -> Dict[int, Any]:
    """Signature stub for ``tuple2dict``: view a tuple sample as an
    index-keyed dict (per the ``Dict[int, Any]`` return annotation)."""

def dict2tuple(d: Dict[int, Any]) -> Tuple:
    """Signature stub for ``dict2tuple``: presumably the inverse of
    ``tuple2dict``."""

def default_collation_fn(
    samples: List[Union[Dict[str, Any], Tuple]],
    combine_tensors: bool = True,
    combine_scalars: bool = True,
) -> Union[Dict[str, Any], Tuple]:
    """Signature stub for ``default_collation_fn``: default batch collation
    for dict or tuple samples (presumably -- see source)."""

# NOTE(review): descriptions below are inferred from names/signatures of this
# stub; confirm against the implementation module.

# Group consecutive elements into lists of ``batchsize`` and collate them with
# ``collation_fn``; ``partial`` presumably controls emitting a short last batch.
def _batched(data: Iterable[T], batchsize: int = 20, collation_fn: Optional[Callable] = default_collation_fn, partial: bool = True) -> Iterator[List[T]]: ...
batched: RestCurried

# Flatten a stream of lists into a stream of their elements.
def _unlisted(data: Iterable[List[T]]) -> Iterator[T]: ...
unlisted: RestCurried

# Split collated batches back into individual samples.
def _unbatched(data: Iterable[Union[Tuple, Dict[str, Any]]]) -> Iterator[Union[Tuple, Dict[str, Any]]]: ...
unbatched: RestCurried

# Randomly keep each element, presumably with probability ``p``.
def _rsample(data: Iterable[T], p: float = 0.5) -> Iterator[T]: ...
rsample: RestCurried

# Curried slicing stage.  NOTE(review): shadows the ``slice`` builtin at
# module scope, and unlike the other filters no ``_slice`` implementation
# function is declared in this stub -- kept as-is for API compatibility.
slice: RestCurried

# Extract the fields matching ``patterns`` from each sample into a tuple;
# ``duplicate_is_error``/``ignore_missing`` presumably control strictness.
def _extract_keys(source: Iterable[Sample], *patterns: str, duplicate_is_error: bool = True, ignore_missing: bool = False) -> Iterator[Tuple]: ...
extract_keys: RestCurried

# Rename sample keys, given either (new, old) pairs in ``args`` or ``new=old``
# keyword arguments (presumably -- confirm against the implementation).
def _rename_keys(source: Iterable[Sample], *args: Tuple[str, str], keep_unselected: bool = False, must_match: bool = True, duplicate_is_error: bool = True, **kw: str) -> Iterator[Sample]: ...
rename_keys: RestCurried

def decode_bin(stream: io.IOBase) -> bytes:
    """Signature stub for ``decode_bin``: read a stream as raw bytes."""

def decode_text(stream: io.IOBase) -> str:
    """Signature stub for ``decode_text``: read a stream and decode to text."""

def decode_pickle(stream: io.IOBase) -> Any:
    """Signature stub for ``decode_pickle``: unpickle a stream.

    NOTE(review): pickle deserialization executes arbitrary code -- the
    implementation must only ever be fed trusted data.
    """

# Default (suffix/pattern, decoder) table consulted by ``_xdecode``; same
# shape as the ``Decoder`` alias declared at the top of this stub.
default_decoders: List[Tuple[str, Callable]]

# Return the decoder from ``decoders`` that matches ``path``, or None when
# nothing matches (presumably -- implementation not visible here).
def find_decoder(decoders: List[Tuple[str, Callable]], path: str) -> Optional[Callable]: ...

# Decode sample fields by filename suffix, using ``args``/``kw`` overrides and
# falling back to ``defaults``.  NOTE(review): the declared default
# ``defaults=default_decoders`` means the real implementation shares one
# mutable list across calls -- callers should not mutate it.
def _xdecode(source: Iterable[Sample], *args: Tuple[str, Callable], must_decode: bool = True, defaults: List[Tuple[str, Callable]] = default_decoders, **kw: Callable) -> Iterator[Sample]: ...
xdecode: RestCurried

class Cached(PipelineStage):
    """Pipeline stage that presumably materializes its source in memory on the
    first pass and replays the cached list on later passes -- bodies are not
    visible in this stub; confirm against the implementation."""

    cached: Optional[List]  # completed cache; None until the first pass finishes (presumably)
    temp: List              # in-progress buffer while the first pass runs (presumably)

    def __init__(self) -> None: ...
    def run(self, source: Iterable[T]) -> Iterator[T]: ...

class LMDBCached(PipelineStage):
    """Pipeline stage that caches samples in an LMDB database file (``fname``)
    so later passes can replay them without re-reading the source -- semantics
    inferred from the API of this stub; confirm against the implementation.

    Serialization uses ``pickler`` (default: the ``pickle`` module), so cache
    files must come from a trusted source.
    """

    db: Any         # open LMDB environment/handle (presumably)
    pickler: Any    # serializer; must provide pickle-style dumps/loads (presumably)
    chunksize: int  # samples per write batch/transaction (presumably)

    def __init__(self, fname: str, map_size: float = 1e12, pickler: Any = pickle, chunksize: int = 500) -> None: ...
    def is_complete(self) -> bool: ...  # whether the cache already holds the full dataset (presumably)
    def add_samples(self, samples: Iterable[Tuple[str, Any]]) -> None: ...  # write (key, value) pairs into the db
    def run(self, source: Iterable[Sample]) -> Iterator[Sample]: ...
